{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Using extra features and descriptors\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to use extra features and descriptors in addition to the default Chemprop featurizers.\n",
    "\n",
    "* Extra atom and bond features are used in addition to those calculated by Chemprop internally. \n",
    "* Extra atom descriptors get incorporated into the atom descriptors from message passing via a learned linear transformation. \n",
    "* Extra bond descriptors are not currently supported because the bond descriptors from message passing are not used for molecular property prediction. \n",
    "* Extra molecule features can be used as extra datapoint descriptors, which are concatenated to the output of the aggregation layer before the final prediction layer."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chemprop/chemprop/blob/main/examples/extra_features_descriptors.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install chemprop from GitHub if running in Google Colab\n",
    "import os\n",
    "\n",
    "if os.getenv(\"COLAB_RELEASE_TAG\"):\n",
    "    try:\n",
    "        import chemprop\n",
    "    except ImportError:\n",
    "        !git clone https://github.com/chemprop/chemprop.git\n",
    "        %cd chemprop\n",
    "        !pip install .\n",
    "        %cd examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Loading packages and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "from lightning import pytorch as pl\n",
    "from rdkit import Chem\n",
    "\n",
    "from chemprop import data, featurizers, models, nn, utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "chemprop_dir = Path.cwd().parent\n",
    "input_path = chemprop_dir / \"tests\" / \"data\" / \"regression\" / \"mol\" / \"mol.csv\"\n",
    "smiles_column = \"smiles\"\n",
    "target_columns = [\"lipo\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_input = pd.read_csv(input_path)\n",
    "smis = df_input.loc[:, smiles_column].values\n",
    "ys = df_input.loc[:, target_columns].values"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Getting extra features and descriptors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `rdkit.Chem.Mol` representation of molecules is needed as input to many featurizers. Chemprop provides a helpful wrapper to rdkit to make these from SMILES."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "mols = [utils.make_mol(smi, keep_h=False, add_h=False) for smi in smis]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extra atom features, atom descriptors, bond features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extra atom and bond features frequently come from QM calculations. The calculation results can be saved to a file and then loaded in a notebook using pandas or numpy. The loaded atom or bond features can be a list of numpy arrays where each numpy array of features corresponds to a single molecule in the dataset. Each row in an array corresponds to a different atom or bond in the same order of atoms or bonds in the `rdkit.Chem.Mol` objects. \n",
    "\n",
    "The atom features could also be used as extra atom descriptors."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This code is just a placeholder for the actual QM calculation\n",
    "\n",
    "\n",
    "def QM_calculation(mol):\n",
    "    n_extra_atom_feats = 10\n",
    "    n_extra_bond_feats = 4\n",
    "    extra_atom_features = np.array([np.random.randn(n_extra_atom_feats) for a in mol.GetAtoms()])\n",
    "    extra_bond_features = np.array([np.random.randn(n_extra_bond_feats) for a in mol.GetBonds()])\n",
    "    return extra_atom_features, extra_bond_features\n",
    "\n",
    "\n",
    "extra_atom_featuress = []\n",
    "extra_bond_featuress = []\n",
    "\n",
    "for mol in mols:\n",
    "    extra_atom_features, extra_bond_features = QM_calculation(mol)\n",
    "    extra_atom_featuress.append(extra_atom_features)\n",
    "    extra_bond_featuress.append(extra_bond_features)\n",
    "\n",
    "# Save to a file\n",
    "np.savez(\"atom_features.npz\", *extra_atom_featuress)\n",
    "np.savez(\"bond_features.npz\", *extra_bond_featuress)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "extra_atom_featuress = np.load(\"atom_features.npz\")\n",
    "extra_atom_featuress = [extra_atom_featuress[f\"arr_{i}\"] for i in range(len(extra_atom_featuress))]\n",
    "\n",
    "extra_atom_descriptorss = extra_atom_featuress\n",
    "\n",
    "extra_bond_featuress = np.load(\"bond_features.npz\")\n",
    "extra_bond_featuress = [extra_bond_featuress[f\"arr_{i}\"] for i in range(len(extra_bond_featuress))]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can also get extra atom and bond features from other sources."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "atom_radii = {1: 0.79, 5: 1.2, 6: 0.91, 7: 0.75, 8: 0.65, 9: 0.57, 16: 1.1, 17: 0.97, 35: 1.1}\n",
    "\n",
    "extra_atom_featuress = [\n",
    "    np.vstack([np.array([[atom_radii[a.GetAtomicNum()]] for a in mol.GetAtoms()])]) for mol in mols\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extra molecule features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A QM calculation could also be used to get extra molecule features. Extra molecule features are different from extra atom and bond features in that they are stored in a single numpy array where each row corresponds to a single molecule in the dataset, instead of a list of numpy arrays."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def QM_calculation(mol):\n",
    "    n_extra_mol_feats = 7\n",
    "    return np.random.randn(n_extra_mol_feats)\n",
    "\n",
    "\n",
    "extra_mol_features = np.array([QM_calculation(mol) for mol in mols])\n",
    "\n",
    "np.savez(\"mol_features.npz\", extra_mol_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "extra_mol_features = np.load(\"mol_features.npz\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The extra molecule features can also be calculated using built-in Chemprop featurizers or featurizers from other packages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "molecule_featurizer = featurizers.MorganBinaryFeaturizer()\n",
    "\n",
    "extra_mol_features = np.array([molecule_featurizer(mol) for mol in mols])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# First install other package\n",
    "# !pip install descriptastorus\n",
    "\n",
    "# from descriptastorus.descriptors import rdNormalizedDescriptors\n",
    "# generator = rdNormalizedDescriptors.RDKit2DNormalized()\n",
    "# extra_mol_features = np.array([generator.process(smi)[1:] for smi in smis])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The molecule featurizers available in Chemprop are registered in `MoleculeFeaturizerRegristry`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "morgan_binary\n",
      "morgan_count\n",
      "rdkit_2d\n",
      "v1_rdkit_2d\n",
      "v1_rdkit_2d_normalized\n"
     ]
    }
   ],
   "source": [
    "for MoleculeFeaturizer in featurizers.MoleculeFeaturizerRegistry.keys():\n",
    "    print(MoleculeFeaturizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If your model takes multiple components as input, you can use extra molecule features for each component as extra datapoint descriptors. Simply concatentate the extra molecule features together before passing them to the datapoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "extra_mol_features_comp1 = np.random.rand(len(mols), 5)\n",
    "extra_mol_features_comp2 = np.random.rand(len(mols), 5)\n",
    "\n",
    "extra_datapoint_descriptors = np.concatenate(\n",
    "    [extra_mol_features_comp1, extra_mol_features_comp2], axis=1\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Making datapoints, datasets, and dataloaders"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you have all the extra features and descriptors your model will use, you can make the datapoints."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "datapoints = [\n",
    "    data.MoleculeDatapoint(mol, y, V_f=V_f, E_f=E_f, V_d=V_d, x_d=X_d)\n",
    "    for mol, y, V_f, E_f, V_d, X_d in zip(\n",
    "        mols,\n",
    "        ys,\n",
    "        extra_atom_featuress,\n",
    "        extra_bond_featuress,\n",
    "        extra_atom_descriptorss,\n",
    "        extra_datapoint_descriptors,\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After splitting the data, the datasets are made. To make a dataset, you need a `MolGraph` featurizer, which needs to be told the size of extra atom and bond features. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_extra_atom_feats = extra_atom_featuress[0].shape[1]\n",
    "n_extra_bond_feats = extra_bond_featuress[0].shape[1]\n",
    "\n",
    "featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer(\n",
    "    extra_atom_fdim=n_extra_atom_feats, extra_bond_fdim=n_extra_bond_feats\n",
    ")\n",
    "\n",
    "train_indices, val_indices, test_indices = data.make_split_indices(mols, \"random\", (0.8, 0.1, 0.1))\n",
    "train_data, val_data, test_data = data.split_data_by_indices(\n",
    "    datapoints, train_indices, val_indices, test_indices\n",
    ")\n",
    "\n",
    "train_dset = data.MoleculeDataset(train_data[0], featurizer)\n",
    "val_dset = data.MoleculeDataset(val_data[0], featurizer)\n",
    "test_dset = data.MoleculeDataset(test_data[0], featurizer)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Often scaling the extra features and descriptors improves model performance. The scalers for the extra features and descriptors should be fit to the training dataset, applied to the validation dataset, and then given to the model to apply to the test dataset at prediction time. This is the same as for scaling target values to improve model performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: black;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-1 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-1 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: block;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-1 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-1 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 1ex;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-1 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>StandardScaler()</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;&nbsp;StandardScaler<a class=\"sk-estimator-doc-link fitted\" rel=\"noreferrer\" target=\"_blank\" href=\"https://scikit-learn.org/1.4/modules/generated/sklearn.preprocessing.StandardScaler.html\">?<span>Documentation for StandardScaler</span></a><span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>StandardScaler()</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "StandardScaler()"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "targets_scaler = train_dset.normalize_targets()\n",
    "extra_atom_features_scaler = train_dset.normalize_inputs(\"V_f\")\n",
    "extra_bond_features_scaler = train_dset.normalize_inputs(\"E_f\")\n",
    "extra_atom_descriptors_scaler = train_dset.normalize_inputs(\"V_d\")\n",
    "extra_datapoint_descriptors_scaler = train_dset.normalize_inputs(\"X_d\")\n",
    "\n",
    "val_dset.normalize_targets(targets_scaler)\n",
    "val_dset.normalize_inputs(\"V_f\", extra_atom_features_scaler)\n",
    "val_dset.normalize_inputs(\"E_f\", extra_bond_features_scaler)\n",
    "val_dset.normalize_inputs(\"V_d\", extra_atom_descriptors_scaler)\n",
    "val_dset.normalize_inputs(\"X_d\", extra_datapoint_descriptors_scaler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Featurize the train and val datasets to save computation time.\n",
    "train_dset.cache = True\n",
    "val_dset.cache = True\n",
    "\n",
    "train_loader = data.build_dataloader(train_dset)\n",
    "val_loader = data.build_dataloader(val_dset, shuffle=False)\n",
    "test_loader = data.build_dataloader(test_dset, shuffle=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Making the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The message passing layer needs to know the total size of atom and bond features (i.e. the sum of the sizes of the Chemprop atom and bond features and the extra atom and bond features). The `MolGraph` featurizer collects this information. The message passing layer also needs to know the number of extra atom descriptors.\n",
    "\n",
    "The extra atom and bond features scalers are combined into a graph transform which is given to the message passing layer to use at prediction time. To avoid scaling the atom and bond features from the internal Chemprop featurizers, the graph transform uses a pad equal to the length of features from the Chemprop internal atom and bond featurizers. This information is stored in the `MolGraph` featurizer.\n",
    "\n",
    "The extra atom descriptor scaler are also converted to a transform and given to the message passing layer to use at prediction time."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_V_features = featurizer.atom_fdim - featurizer.extra_atom_fdim\n",
    "n_E_features = featurizer.bond_fdim - featurizer.extra_bond_fdim\n",
    "\n",
    "V_f_transform = nn.ScaleTransform.from_standard_scaler(extra_atom_features_scaler, pad=n_V_features)\n",
    "E_f_transform = nn.ScaleTransform.from_standard_scaler(extra_bond_features_scaler, pad=n_E_features)\n",
    "\n",
    "graph_transform = nn.GraphTransform(V_f_transform, E_f_transform)\n",
    "\n",
    "V_d_transform = nn.ScaleTransform.from_standard_scaler(extra_atom_descriptors_scaler)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_extra_atom_descs = extra_atom_descriptorss[0].shape[1]\n",
    "\n",
    "mp = nn.BondMessagePassing(\n",
    "    d_v=featurizer.atom_fdim,\n",
    "    d_e=featurizer.bond_fdim,\n",
    "    d_vd=n_extra_atom_descs,\n",
    "    graph_transform=graph_transform,\n",
    "    V_d_transform=V_d_transform,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The predictor layer needs to know the size of the its input, including any extra datapoint descriptors. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "ffn_input_dim = mp.output_dim + extra_datapoint_descriptors.shape[1]\n",
    "\n",
    "output_transform = nn.UnscaleTransform.from_standard_scaler(targets_scaler)\n",
    "ffn = nn.RegressionFFN(input_dim=ffn_input_dim, output_transform=output_transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The overall model is given the transform from the extra datapoint descriptors scaler."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_d_transform = nn.ScaleTransform.from_standard_scaler(extra_datapoint_descriptors_scaler)\n",
    "\n",
    "chemprop_model = models.MPNN(mp, nn.NormAggregation(), ffn, X_d_transform=X_d_transform)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training and prediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The rest of the training and prediction are the same as other Chemprop workflows."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: False, used: False\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    }
   ],
   "source": [
    "trainer = pl.Trainer(\n",
    "    logger=False, enable_checkpointing=False, enable_progress_bar=True, max_epochs=5\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading `train_dataloader` to estimate number of stepping batches.\n",
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n",
      "\n",
      "  | Name            | Type               | Params | Mode \n",
      "---------------------------------------------------------------\n",
      "0 | message_passing | BondMessagePassing | 325 K  | train\n",
      "1 | agg             | NormAggregation    | 0      | train\n",
      "2 | bn              | Identity           | 0      | train\n",
      "3 | predictor       | RegressionFFN      | 96.6 K | train\n",
      "4 | X_d_transform   | ScaleTransform     | 0      | train\n",
      "5 | metrics         | ModuleList         | 0      | train\n",
      "---------------------------------------------------------------\n",
      "422 K     Trainable params\n",
      "0         Non-trainable params\n",
      "422 K     Total params\n",
      "1.690     Total estimated model params size (MB)\n",
      "27        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████| 2/2 [00:00<00:00,  2.22it/s, train_loss_step=1.340, val_loss=8.520, train_loss_epoch=0.936]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=5` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4: 100%|██████████| 2/2 [00:00<00:00,  2.21it/s, train_loss_step=1.340, val_loss=8.520, train_loss_epoch=0.936]\n"
     ]
    }
   ],
   "source": [
    "trainer.fit(chemprop_model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 11.20it/s]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test/mse          </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9464761018753052     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test/mse         \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9464761018753052    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "results = trainer.test(chemprop_model, test_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemprop",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
